{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Задание 5.1 - Word2Vec\n",
    "\n",
    "В этом задании мы натренируем свои word vectors на очень небольшом датасете.\n",
    "Мы будем использовать самую простую версию word2vec, без negative sampling и других оптимизаций.\n",
    "\n",
    "Перед запуском нужно запустить скрипт `download_data.sh` чтобы скачать данные.\n",
    "\n",
    "Датасет и модель очень небольшие, поэтому это задание можно выполнить и без GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# We'll use Principal Component Analysis (PCA) to visualize word vectors,\n",
    "# so make sure you install dependencies from requirements.txt!\n",
    "from sklearn.decomposition import PCA \n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "\n",
    "class StanfordTreeBank:\n",
    "    '''\n",
    "    Wrapper for accessing Stanford Tree Bank Dataset\n",
    "    https://nlp.stanford.edu/sentiment/treebank.html\n",
    "    \n",
    "    Parses dataset, gives each token and index and provides lookups\n",
    "    from string token to index and back\n",
    "    \n",
    "    Allows to generate random context with sampling strategy described in\n",
    "    word2vec paper:\n",
    "    https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf\n",
    "    '''\n",
    "    def __init__(self):\n",
    "        self.index_by_token = {}\n",
    "        self.token_by_index = []\n",
    "\n",
    "        self.sentences = []\n",
    "\n",
    "        self.token_freq = {}\n",
    "        \n",
    "        self.token_reject_by_index = None\n",
    "\n",
    "    def load_dataset(self, folder):\n",
    "        filename = os.path.join(folder, \"datasetSentences.txt\")\n",
    "\n",
    "        with open(filename, \"r\", encoding=\"latin1\") as f:\n",
    "            l = f.readline() # skip the first line\n",
    "            \n",
    "            for l in f:\n",
    "                splitted_line = l.strip().split()\n",
    "                words = [w.lower() for w in splitted_line[1:]] # First one is a number\n",
    "                    \n",
    "                self.sentences.append(words)\n",
    "                for word in words:\n",
    "                    if word in self.token_freq:\n",
    "                        self.token_freq[word] +=1 \n",
    "                    else:\n",
    "                        index = len(self.token_by_index)\n",
    "                        self.token_freq[word] = 1\n",
    "                        self.index_by_token[word] = index\n",
    "                        self.token_by_index.append(word)\n",
    "        self.compute_token_prob()\n",
    "                        \n",
    "    def compute_token_prob(self):\n",
    "        words_count = np.array([self.token_freq[token] for token in self.token_by_index])\n",
    "        words_freq = words_count / np.sum(words_count)\n",
    "        \n",
    "        # Following sampling strategy from word2vec paper:\n",
    "        # https://papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf\n",
    "        self.token_reject_by_index = 1- np.sqrt(1e-5/words_freq)\n",
    "    \n",
    "    def check_reject(self, word):\n",
    "        return np.random.rand() > self.token_reject_by_index[self.index_by_token[word]]\n",
    "        \n",
    "    def get_random_context(self, context_length=5):\n",
    "        \"\"\"\n",
    "        Returns tuple of center word and list of context words\n",
    "        \"\"\"\n",
    "        sentence_sampled = []\n",
    "        while len(sentence_sampled) <= 2:\n",
    "            sentence_index = np.random.randint(len(self.sentences)) \n",
    "            sentence = self.sentences[sentence_index]\n",
    "            sentence_sampled = [word for word in sentence if self.check_reject(word)]\n",
    "    \n",
    "        center_word_index = np.random.randint(len(sentence_sampled))\n",
    "        \n",
    "        words_before = sentence_sampled[max(center_word_index - context_length//2,0):center_word_index]\n",
    "        words_after = sentence_sampled[center_word_index+1: center_word_index+1+context_length//2]\n",
    "        \n",
    "        return sentence_sampled[center_word_index], words_before+words_after\n",
    "    \n",
    "    def num_tokens(self):\n",
    "        return len(self.token_by_index)\n",
    "        \n",
    "data = StanfordTreeBank()\n",
    "data.load_dataset(\"./stanfordSentimentTreebank/\")\n",
    "\n",
    "print(\"Num tokens:\", data.num_tokens())\n",
    "for i in range(5):\n",
    "    center_word, other_words = data.get_random_context(5)\n",
    "    print(center_word, other_words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Имплеменируем PyTorch-style Dataset для Word2Vec\n",
    "\n",
    "Этот Dataset должен сгенерировать много случайных контекстов и превратить их в сэмплы для тренировки.\n",
    "\n",
    "Напоминаем, что word2vec модель получает на вход One-hot вектор слова и тренирует простую сеть для предсказания на его основе соседних слов.\n",
    "Из набора слово-контекст создается N сэмплов (где N - количество слов в контексте):\n",
    "\n",
    "Например:\n",
    "\n",
    "Слово: `orders` и контекст: `['love', 'nicest', 'to', '50-year']` создадут 4 сэмпла:\n",
    "- input: `orders`, target: `love`\n",
    "- input: `orders`, target: `nicest`\n",
    "- input: `orders`, target: `to`\n",
    "- input: `orders`, target: `50-year`\n",
    "\n",
    "Все слова на входе и на выходе закодированы через one-hot encoding, с размером вектора равным количеству токенов."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Word2VecPlain(Dataset):\n",
    "    '''\n",
    "    PyTorch Dataset for plain Word2Vec.\n",
    "    Accepts StanfordTreebank as data and is able to generate dataset based on\n",
    "    a number of random contexts\n",
    "    '''\n",
    "    def __init__(self, data, num_contexts=30000):\n",
    "        '''\n",
    "        Initializes Word2VecPlain, but doesn't generate the samples yet\n",
    "        (for that, use generate_dataset)\n",
    "        Arguments:\n",
    "        data - StanfordTreebank instace\n",
    "        num_contexts - number of random contexts to use when generating a dataset\n",
    "        '''\n",
    "        # TODO: Implement what you need for other methods!\n",
    "    \n",
    "    def generate_dataset(self):\n",
    "        '''\n",
    "        Generates dataset samples from random contexts\n",
    "        Note: there will be more samples than contexts because every context\n",
    "        can generate more than one sample\n",
    "        '''\n",
    "        # TODO: Implement generating the dataset\n",
    "        # You should sample num_contexts contexts from the data and turn them into samples\n",
    "        # Note you will have several samples from one context\n",
    "        \n",
    "    def __len__(self):\n",
    "        '''\n",
    "        Returns total number of samples\n",
    "        '''\n",
    "        # TODO: Return the number of samples\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        '''\n",
    "        Returns i-th sample\n",
    "        \n",
    "        Return values:\n",
    "        input_vector - torch.Tensor with one-hot representation of the input vector\n",
    "        output_index - index of the target word (not torch.Tensor!)\n",
    "        '''\n",
    "        # TODO: Generate tuple of 2 return arguments for i-th sample    \n",
    "\n",
    "dataset = Word2VecPlain(data, 10)\n",
    "dataset.generate_dataset()\n",
    "input_vector, target = dataset[3]\n",
    "print(\"Sample - input: %s, target: %s\" % (input_vector, int(target))) # target should be able to convert to int\n",
    "assert isinstance(input_vector, torch.Tensor)\n",
    "assert torch.sum(input_vector) == 1.0\n",
    "assert input_vector.shape[0] == data.num_tokens()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Создаем модель и тренируем ее"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the usual PyTorch structures\n",
    "dataset = Word2VecPlain(data, 30000)\n",
    "dataset.generate_dataset()\n",
    "\n",
    "# We'll be training very small word vectors!\n",
    "wordvec_dim = 10\n",
    "\n",
    "# We can use a standard sequential model for this\n",
    "nn_model = nn.Sequential(\n",
    "            nn.Linear(dataset.num_tokens, wordvec_dim, bias=False),\n",
    "            nn.Linear(wordvec_dim, dataset.num_tokens, bias=False), \n",
    "         )\n",
    "nn_model.type(torch.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_word_vectors(nn_model):\n",
    "    '''\n",
    "    Extracts word vectors from the model\n",
    "    \n",
    "    Returns:\n",
    "    input_vectors: torch.Tensor with dimensions (num_tokens, num_dimensions)\n",
    "    output_vectors: torch.Tensor with dimensions (num_tokens, num_dimensions)\n",
    "    '''\n",
    "    # TODO: Implement extracting word vectors from param weights\n",
    "    # return tuple of input vectors and output vectos \n",
    "    # Hint: you can access weights as Tensors through nn.Linear class attributes\n",
    "\n",
    "untrained_input_vectors, untrained_output_vectors = extract_word_vectors(nn_model)\n",
    "assert untrained_input_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "assert untrained_output_vectors.shape == (data.num_tokens(), wordvec_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, dataset, train_loader, optimizer, scheduler, num_epochs):\n",
    "    '''\n",
    "    Trains plain word2vec using cross-entropy loss and regenerating dataset every epoch\n",
    "    \n",
    "    Returns:\n",
    "    loss_history, train_history\n",
    "    '''\n",
    "    \n",
    "    loss = nn.CrossEntropyLoss().type(torch.FloatTensor)\n",
    "    \n",
    "    loss_history = []\n",
    "    train_history = []\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train() # Enter train mode\n",
    "        \n",
    "        dataset.generate_dataset() # Regenerate dataset every epoch\n",
    "        \n",
    "        # TODO Implement training for this model\n",
    "        # Note we don't have any validation set here because our purpose is the word vectors,\n",
    "        # not the predictive performance of the model\n",
    "        #\n",
    "        # And don't forget to step the learing rate scheduler!  \n",
    "        print(\"Epoch %i, Average loss: %f, Train accuracy: %f\" % (epoch, ave_loss, train_accuracy))\n",
    "        \n",
    "    return loss_history, train_history"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ну и наконец тренировка!\n",
    "\n",
    "Добейтесь значения ошибки меньше **8.0**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Finally, let's train the model!\n",
    "\n",
    "# TODO: We use placeholder values for hyperparameters - you will need to find better values!\n",
    "optimizer = optim.SGD(nn_model.parameters(), lr=1e-1, weight_decay=0)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)\n",
    "train_loader = torch.utils.data.DataLoader(dataset, batch_size=20)\n",
    "\n",
    "loss_history, train_history = train_model(nn_model, dataset, train_loader, optimizer, scheduler, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize training graphs\n",
    "plt.subplot(211)\n",
    "plt.plot(train_history)\n",
    "plt.subplot(212)\n",
    "plt.plot(loss_history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Визуализируем вектора для разного вида слов до и после тренировки\n",
    "\n",
    "В случае успешной тренировки вы должны увидеть как вектора слов разных типов (например, знаков препинания, предлогов и остальных) разделяются семантически.\n",
    "\n",
    "Студенты - в качестве выполненного задания присылайте notebook с диаграммами!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trained_input_vectors, trained_output_vectors = extract_word_vectors(nn_model)\n",
    "assert trained_input_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "assert trained_output_vectors.shape == (data.num_tokens(), wordvec_dim)\n",
    "\n",
    "def visualize_vectors(input_vectors, output_vectors, title=''):\n",
    "    full_vectors = torch.cat((input_vectors, output_vectors), 0)\n",
    "    wordvec_embedding = PCA(n_components=2).fit_transform(full_vectors)\n",
    "\n",
    "    # Helpful words form CS244D example\n",
    "    # http://cs224d.stanford.edu/assignment1/index.html\n",
    "    visualize_words = {'green': [\"the\", \"a\", \"an\"], \n",
    "                      'blue': [\",\", \".\", \"?\", \"!\", \"``\", \"''\", \"--\"], \n",
    "                      'brown': [\"good\", \"great\", \"cool\", \"brilliant\", \"wonderful\", \n",
    "                              \"well\", \"amazing\", \"worth\", \"sweet\", \"enjoyable\"],\n",
    "                      'orange': [\"boring\", \"bad\", \"waste\", \"dumb\", \"annoying\", \"stupid\"],\n",
    "                      'red': ['tell', 'told', 'said', 'say', 'says', 'tells', 'goes', 'go', 'went']\n",
    "                     }\n",
    "\n",
    "    plt.figure(figsize=(7,7))\n",
    "    plt.suptitle(title)\n",
    "    for color, words in visualize_words.items():\n",
    "        points = np.array([wordvec_embedding[data.index_by_token[w]] for w in words])\n",
    "        for i, word in enumerate(words):\n",
    "            plt.text(points[i, 0], points[i, 1], word, color=color,horizontalalignment='center')\n",
    "        plt.scatter(points[:, 0], points[:, 1], c=color, alpha=0.3, s=0.5)\n",
    "\n",
    "visualize_vectors(untrained_input_vectors, untrained_output_vectors, \"Untrained word vectors\")\n",
    "visualize_vectors(trained_input_vectors, trained_output_vectors, \"Trained word vectors\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
